package com.gitee.fastmybatis.core.support.datasource;

import com.gitee.fastmybatis.core.util.IOUtil;

import javax.sql.DataSource;
import java.io.IOException;
import java.io.PrintWriter;
import java.sql.*;
import java.util.*;
import java.util.logging.Logger;

/**
 * Fluent builder for a simple, non-pooling {@link DataSource} backed by an H2 database.
 *
 * <p>By default it targets {@code jdbc:h2:mem:test;DB_CLOSE_DELAY=-1} with the
 * {@code org.h2.Driver} driver and an empty user/password. Optional SQL scripts registered via
 * {@link #addScript(String...)} are executed once, in a single transaction, when
 * {@link #build()} is called.
 *
 * <p>Instances are not thread-safe; configure and build from a single thread.
 *
 * @author tanghc
 */
public class H2MemDataSourceBuilder {

    private String driverClassName = "org.h2.Driver";
    private String url = "jdbc:h2:mem:test;DB_CLOSE_DELAY=-1";
    private String username = "";
    private String password = "";

    /** Script locations, processed in insertion order by {@link #build()}. */
    private final List<String> scripts = new ArrayList<>(8);

    /**
     * Sets the JDBC driver class that is loaded before the driver is looked up.
     *
     * @param driverClassName fully-qualified driver class name
     * @return this builder
     */
    public H2MemDataSourceBuilder driverClassName(String driverClassName) {
        this.driverClassName = driverClassName;
        return this;
    }

    /**
     * Sets the JDBC URL used to locate the driver and open connections.
     *
     * @param url JDBC URL
     * @return this builder
     */
    public H2MemDataSourceBuilder url(String url) {
        this.url = url;
        return this;
    }

    /**
     * Registers script locations to execute when the DataSource is built.
     *
     * <p>Each location is resolved through {@code IOUtil.listJarFiles(location, "sql")}, and every
     * resulting resource's content is executed. A {@code null} array or {@code null} elements are
     * ignored.
     *
     * @param script one or more script locations
     * @return this builder
     */
    public H2MemDataSourceBuilder addScript(String... script) {
        if (script == null) {
            return this;
        }
        for (String location : script) {
            if (location != null) {
                scripts.add(location);
            }
        }
        return this;
    }

    /**
     * Sets the credentials used for connections opened via {@link DataSource#getConnection()}.
     *
     * @param username database user; {@code null} means "not sent to the driver"
     * @param password database password; {@code null} means "not sent to the driver"
     * @return this builder
     */
    public H2MemDataSourceBuilder account(String username, String password) {
        this.username = username;
        this.password = password;
        return this;
    }

    /**
     * Loads the driver, creates the DataSource and runs all registered scripts.
     *
     * @return a ready-to-use DataSource
     * @throws RuntimeException if the driver class cannot be loaded, no driver accepts the URL,
     *                          or a script fails (the script transaction is rolled back)
     */
    public DataSource build() {
        try {
            Class.forName(driverClassName);
        } catch (ClassNotFoundException e) {
            throw new RuntimeException("driverClassNotFound:" + driverClassName, e);
        }
        try {
            Driver driver = DriverManager.getDriver(url);
            H2MemDataSource h2MemDataSource = new H2MemDataSource(driver, url, username, password);
            runScript(h2MemDataSource);
            return h2MemDataSource;
        } catch (SQLException | IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Executes every registered script on one connection inside a single transaction.
     * On any failure the transaction is rolled back and the original exception is rethrown.
     * The connection is always closed.
     */
    private void runScript(DataSource dataSource) throws SQLException, IOException {
        if (scripts.isEmpty()) {
            return;
        }
        try (Connection connection = dataSource.getConnection()) {
            connection.setAutoCommit(false);
            try {
                for (String script : scripts) {
                    List<IOUtil.ResourceFile> resourceFiles = IOUtil.listJarFiles(script, "sql");
                    for (IOUtil.ResourceFile resourceFile : resourceFiles) {
                        executeSql(connection, resourceFile.getContent());
                    }
                }
                connection.commit();
            } catch (Exception e) {
                // Undo partially-applied scripts. A rollback failure must not mask the
                // root cause, so it is attached as a suppressed exception instead.
                try {
                    connection.rollback();
                } catch (SQLException rollbackFailure) {
                    e.addSuppressed(rollbackFailure);
                }
                // Precise rethrow: only the checked exceptions the try body can throw escape here.
                throw e;
            }
        }
    }

    /** Executes one script's SQL text on the given connection; blank content is skipped. */
    private static void executeSql(Connection connection, String sql) throws SQLException {
        if (sql == null || sql.trim().isEmpty()) {
            return;
        }
        try (PreparedStatement statement = connection.prepareStatement(sql)) {
            // execute() accepts any statement kind, unlike executeUpdate(), which rejects queries.
            statement.execute();
        }
    }


    /**
     * Minimal, non-pooling {@link DataSource}. Every {@code getConnection} call opens a new
     * connection directly through the wrapped {@link Driver}.
     */
    public static class H2MemDataSource implements DataSource {

        private final Driver driver;
        private final String url;

        /** Default connection properties ({@code user}/{@code password}) built from the constructor args. */
        private final Properties properties;

        /**
         * @param driver   driver used to open connections
         * @param url      JDBC URL passed to {@link Driver#connect(String, Properties)}
         * @param username default user; omitted from the connection properties when {@code null}
         * @param password default password; omitted from the connection properties when {@code null}
         */
        public H2MemDataSource(Driver driver, String url, String username, String password) {
            this.driver = driver;
            this.url = url;
            this.properties = toConnectionProperties(username, password);
        }

        @Override
        public Connection getConnection() throws SQLException {
            return connect(properties);
        }

        /** Opens a connection using the supplied credentials instead of the defaults. */
        @Override
        public Connection getConnection(String username, String password) throws SQLException {
            return connect(toConnectionProperties(username, password));
        }

        /**
         * Opens a connection through the driver, turning the driver's "URL not accepted"
         * {@code null} result into an exception.
         */
        private Connection connect(Properties info) throws SQLException {
            Connection connection = driver.connect(url, info);
            if (connection == null) {
                throw new SQLException("Driver [" + driver.getClass().getName()
                        + "] did not accept the configured JDBC url");
            }
            return connection;
        }

        /** Builds driver properties; {@link Properties} rejects null values, so nulls are omitted. */
        private static Properties toConnectionProperties(String username, String password) {
            Properties info = new Properties();
            if (username != null) {
                info.setProperty("user", username);
            }
            if (password != null) {
                info.setProperty("password", password);
            }
            return info;
        }

        @Override
        @SuppressWarnings("unchecked")
        public <T> T unwrap(Class<T> iface) throws SQLException {
            if (iface.isInstance(this)) {
                return (T) this;
            }
            throw new SQLException("DataSource of type [" + getClass().getName() +
                    "] cannot be unwrapped as [" + iface.getName() + "]");
        }

        @Override
        public boolean isWrapperFor(Class<?> iface) throws SQLException {
            return iface.isInstance(this);
        }

        /** Always {@code null}: this DataSource does not log, which the DataSource contract permits. */
        @Override
        public PrintWriter getLogWriter() throws SQLException {
            return null;
        }

        /** Ignored: this DataSource does not log. */
        @Override
        public void setLogWriter(PrintWriter out) throws SQLException {
            // intentionally a no-op
        }

        /** Ignored: the login timeout is not applied to {@link Driver#connect}. */
        @Override
        public void setLoginTimeout(int seconds) throws SQLException {
            // intentionally a no-op
        }

        /** Always {@code 0}, meaning the system/driver default timeout. */
        @Override
        public int getLoginTimeout() throws SQLException {
            return 0;
        }

        @Override
        public Logger getParentLogger() throws SQLFeatureNotSupportedException {
            return Logger.getLogger(Logger.GLOBAL_LOGGER_NAME);
        }
    }


}
